// etpredict.cpp - 2016 - Atlee Brink
// extremely randomized trees model applicator / predictor

#include "nexamples.hpp"
#include "nextratrees.hpp"

#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <vector>

namespace {
    using namespace std;
    using namespace nexamples;
    using namespace nextratrees;

    // program name shown in the usage line
    const string executablename = "etpredict";

    // summary of the required parameters, shown in the usage line
    const string MANDATORY_PARAMETERS = "-t <in:testfile.csv> -m <in:modelfile>"
        " [other optional parameters]";

    // one help entry per parameter, printed by showusage()
    const string PARAMETER_DESCRIPTIONS[] = {
        "\t-a"
            "\n\t\tcompute accuracy"
            "\n\t\tNOTE: requires that testfile.csv contains a label column",
        "\t-i <index_column_name>"
            "\n\t\tname of the index column to output with predictions"
            "\n\t\tNOTE: this column must have been excluded with -e during training",
        "\t-m <in:modelfile>"
            "\n\t\tspecify input model file, as generated by ettrain",
        "\t-p <out:predictions>"
            "\n\t\tspecify output predictions file",
        "\t-t <in:testfile.csv>"
            "\n\t\tspecify input testing data file, in comma-separated-value format"
    };

    // settings filled in by processarguments()
    bool computeaccuracy = false; // -a
    string indexcolumn;           // -i <index_column_name>
    string testfile;              // -t <in:testfile.csv>
    string modelfile;             // -m <in:modelfile>
    string predictionsfile;       // -p <out:predictions>

    /// Prints the usage line followed by every parameter description to stdout.
    void
    showusage() {
        cout << "usage: " << executablename << " " << MANDATORY_PARAMETERS << "\n\n";
        for( const auto &desc : PARAMETER_DESCRIPTIONS ) cout << desc << "\n\n";
    }

    /// Parses the command line into the settings above.
    ///
    /// Recognized flags: -i, -m, -p, -t (each followed by a value) and -a.
    /// @param argc  argument count as passed to main
    /// @param argv  argument vector as passed to main
    /// @return true when -t and -m were given and there is something to do
    ///         (-a and/or -p); false otherwise. A reason is printed to stdout
    ///         for every failure except "too few arguments", where the caller's
    ///         usage text is the explanation.
    bool
    processarguments( int argc, char *argv[] ) {
        // need at least 4 arguments: -t testfile and -m modelfile
        if( argc < 1 + 2 + 2 ) return false;

        // process command line
        for( int argi = 1; argi < argc; argi++ ) {
            const string arg( argv[ argi ] );

            // flags that take a value: map each to the setting it fills
            string *target = nullptr;
            if( arg == "-i" ) target = &indexcolumn;          // index column name
            else if( arg == "-m" ) target = &modelfile;       // modelfile (input)
            else if( arg == "-p" ) target = &predictionsfile; // predictionsfile (output)
            else if( arg == "-t" ) target = &testfile;        // testfile (input)

            if( target != nullptr ) {
                // a value flag in last position is missing its value, not unrecognized
                if( argi + 1 >= argc ) {
                    cout << "missing value for command line parameter: " << arg << "\n\n";
                    return false;
                }
                *target = argv[ ++argi ]; // consume the value as well
                continue;
            }

            // flags without a value
            if( arg == "-a" ) { // compute accuracy
                computeaccuracy = true;
                continue;
            }

            cout << "unrecognized command line parameter: " << arg << "\n\n";
            return false;
        }

        // check that at least testfile and modelfile were set
        bool isgood = true;
        if( testfile.empty() ) {
            cout << "command line parameter needed: -t <in:testfile.csv>\n";
            isgood = false;
        }
        if( modelfile.empty() ) {
            cout << "command line parameter needed: -m <in:modelfile>\n";
            isgood = false;
        }

        // check that at least one of -a or -p was used
        if( !computeaccuracy && predictionsfile.empty() ) {
            cout << "nothing to do: I suggest -a or -p <out:predictionsfile>\n";
            isgood = false;
        }

        if( !isgood ) cout << "\n";

        return isgood;
    }
}

/// Entry point: loads an extremely randomized trees model (-m) and a test set
/// (-t), predicts a label for every test example, and optionally writes the
/// predictions as CSV (-p) and/or reports accuracy against the test labels (-a).
///
/// Output CSV layout: a header row "[indexcolumn,]labelname", then one row per
/// test example in test-set order.
///
/// @return 0 on success; 1 on invalid arguments, unreadable input, a missing
///         index column, unlabeled data with -a, or a failure writing -p output.
int main( int argc, char *argv[] ) {

    cout << "extremely randomized trees predictor, coded by Atlee Brink\n\n";

    if( !processarguments( argc, argv ) ) {
        showusage();
        return 1; // invalid usage is a failure, so callers/scripts can detect it
    }

    // all the stuff that goes in a model
    string labelname;
    vector< string > exnames;
    vector< string > featurenames;
    nextratrees::forest_t forest;
    size_t nmin, numattr, optimizationlayers; // required by the loader; unused here

    // load model from file
    cout << "loading model..." << flush;
    if( !nextratrees::loadmodelfromfile(
        modelfile,
        labelname,
        exnames,
        featurenames,
        forest,
        nmin,
        numattr,
        optimizationlayers )
    ) {
        cerr << "failed to load model from: " << modelfile << endl;
        return 1;
    }
    cout << "done" << endl;

    // the model's excluded column names, as the set the test-set loader expects
    set< string > excludedfeatures( exnames.begin(), exnames.end() );

    // load test set from file
    cout << "loading test set..." << flush;
    cexampleset testset;
    bool islabeled = false;
    if( !testset.loadfromfile(
        testfile,
        labelname,
        islabeled, // in: says we don't care, out: says whether it was found
        excludedfeatures )
    ) {
        cerr << "failed to load test set from: " << testfile << endl;
        return 1;
    }
    cout << "done" << endl;

    // find the position of the requested index column among the test set's
    // excluded columns; the search is bounded because the column may be absent
    size_t indexcolumnindex = 0;
    bool usingindex = false;
    if( !indexcolumn.empty() ) {
        while( indexcolumnindex < testset.exnames.size()
            && testset.exnames[ indexcolumnindex ] != indexcolumn ) {
            indexcolumnindex++;
        }
        if( indexcolumnindex >= testset.exnames.size() ) {
            cerr << "index column not found among excluded columns of test set: "
                 << indexcolumn << endl;
            return 1;
        }
        usingindex = true;
    }

    // if asked to compute accuracy, check that the data is labeled
    if( computeaccuracy && !islabeled ) {
        cerr << "asked to compute accuracy, but data is not labeled\n";
        return 1;
    }

    // open the predictions file (if requested) before starting the progress line,
    // so an open failure is reported on its own line and with a failing status
    const bool dostorepredictions = !predictionsfile.empty();
    ofstream outfile; // stays unopened unless -p was given
    if( dostorepredictions ) {
        outfile.open( predictionsfile );
        if( !outfile ) {
            cerr << "error creating output file: " << predictionsfile << endl;
            return 1;
        }
        // header row: optional index column, then the label column
        if( usingindex ) outfile << indexcolumn << ",";
        outfile << labelname << "\n";
    }

    // predict labels
    cout << "predicting..." << flush;
    size_t numcorrect = 0; // only meaningful with -a
    auto Ii = testset.exfeaturevectors.cbegin(); // advanced only when writing the index column
    auto Yi = testset.labels.cbegin(); // may be empty; advanced only with -a (labels checked above)
    for( auto &X : testset.featurevectors ) {

        const auto predictedlabel = forest.classify( X );

        if( computeaccuracy ) numcorrect += predictedlabel == *Yi++;
        if( dostorepredictions ) {
            if( usingindex ) outfile << (*Ii++)[ indexcolumnindex ] << ",";
            outfile << predictedlabel << "\n";
        }

    }
    cout << "done" << endl;

    // flush and verify the predictions were actually written
    if( dostorepredictions ) {
        outfile.close();
        if( !outfile ) {
            cerr << "error writing output file: " << predictionsfile << endl;
            return 1;
        }
    }

    if( computeaccuracy ) {
        const auto numexamples = testset.featurevectors.size();
        if( numexamples == 0 ) {
            // avoid reporting 0/0 = NaN
            cout << "accuracy on test set: undefined (test set is empty)\n";
        }
        else {
            const double accuracy =
                100.0 * static_cast< double >( numcorrect ) / numexamples;
            cout << "accuracy on test set: " << accuracy << "%\n";
        }
    }

    return 0;
}
